# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import copy
import wandb
import torch
import time
import pickle
from os.path import join
from datetime import datetime

from utils import utils
from configs.datasets_config import get_dataset_info
from configs.parse_args import parse_args
from utils.utils import assert_correctly_masked
from utils import utils as da_utils

from qm9 import dataset as qmds
from qm9.utils import prepare_context, compute_mean_mad

from dadm import da_diffusion
from train_test import train_epoch, test, analyze_and_save
from dadm.get_models import get_optim, get_autoencoder, get_da_diffusion

# Parse the command-line options and resolve the wandb account to log under.
args = parse_args()
args.wandb_usr = utils.get_wandb_username(args.wandb_usr)

# CUDA is used only when it was not disabled on the command line and a device is available.
if args.no_cuda:
    args.cuda = False
else:
    args.cuda = torch.cuda.is_available()
device = torch.device("cuda") if args.cuda else torch.device("cpu")
dtype = torch.float32

if args.resume is not None:
    # Remember the command-line values that must survive the reload below.
    exp_name = args.exp_name + '_resume'
    start_epoch, resume, wandb_usr = args.start_epoch, args.resume, args.wandb_usr
    normalization_factor, aggregation_method = args.normalization_factor, args.aggregation_method

    # Swap the parsed arguments for the ones pickled next to the checkpoint.
    # NOTE: unpickling can execute arbitrary code -- only resume from trusted directories.
    with open(join(resume, 'args.pickle'), 'rb') as f:
        args = pickle.load(f)

    # Re-apply the command-line values on top of the stored configuration.
    args.resume = resume
    args.break_train_epoch = False
    args.exp_name = exp_name
    args.start_epoch = start_epoch
    args.wandb_usr = wandb_usr

    # Careful with this -->
    # The command-line value is used only when the stored args lack the attribute.
    args.normalization_factor = getattr(args, 'normalization_factor', normalization_factor)
    args.aggregation_method = getattr(args, 'aggregation_method', aggregation_method)

    print(args)

# Get the current date and time
current_datetime = datetime.now()

# Every artifact of this run is written below this prefix.
save_dir = './Models/' + args.dataset + '_'
utils.create_folders(args, save_dir)

# The log file stays open for the whole run and is never closed explicitly, so it is
# opened line-buffered (buffering=1): each completed line reaches the file immediately
# instead of being lost in the buffer if training crashes or is killed.
log = open(save_dir + 'outputs/%s/log.txt' % args.exp_name, mode='w+', buffering=1)
print(f'Train on dataset {args.dataset}', file=log)
print(f'Train on dataset {args.dataset}')

# Wandb config: logging is disabled, online, or offline depending on the flags.
if args.no_wandb:
    mode = 'disabled'
elif args.online:
    mode = 'online'
else:
    mode = 'offline'
kwargs = dict(
    entity=args.wandb_usr,
    name=args.exp_name,
    project='ood_diffusion_mg',
    config=args,
    settings=wandb.Settings(_disable_stats=True),
    reinit=True,
    mode=mode,
)
wandb.init(**kwargs)
wandb.save('*.txt')

dataloaders, charge_scale = qmds.retrieve_dataloaders(args)

# One training batch, used below to build a sample context when conditioning is on.
data_dummy = next(iter(dataloaders['train']))

if len(args.conditioning) == 0:
    # Unconditional training: no context features and no property normalisation.
    context_node_nf = 0
    property_norms = None
    # Align VAE output size and self condition size
    args.self_condition_nf = args.latent_nf
else:
    print(f'Conditioning on {args.conditioning}')
    log.write(f'Conditioning on {args.conditioning} \n')
    property_norms = compute_mean_mad(dataloaders, args.conditioning, args.dataset)
    context_dummy = prepare_context(args.conditioning, data_dummy, property_norms)
    # The context feature size is read from dimension 2 of the sample context.
    context_node_nf = context_dummy.size(2)

args.context_node_nf = context_node_nf

dataset_info = get_dataset_info(args.dataset, args.remove_h)
args.ood_element_size = len(dataset_info['atom_decoder'])
n_nodes_info = dataset_info['n_nodes']

# Create Latent Diffusion Model or Autoencoder
if not args.train_diffusion:
    model, nodes_dist, prop_dist = get_autoencoder(args, device, dataset_info, dataloaders)
else:
    model, nodes_dist, prop_dist = get_da_diffusion(args, device, dataloaders, n_nodes_info)

# Hand the property normalisation statistics to the property distribution, if there is one.
if prop_dist is not None:
    prop_dist.set_normalizer(property_norms)

model = model.to(device)
optim = get_optim(args, model)

# Queue of gradient norms passed to train_epoch.
gradnorm_queue = utils.Queue()
gradnorm_queue.add(3000)  # Add large value that will be flushed.

def check_mask_correct(variables, node_mask):
    """Check every non-empty entry of ``variables`` against ``node_mask``.

    Entries whose length is zero are skipped; each remaining entry is passed,
    together with ``node_mask``, to ``assert_correctly_masked``.
    """
    non_empty = (entry for entry in variables if len(entry) > 0)
    for entry in non_empty:
        assert_correctly_masked(entry, node_mask)


def main():
    """Train the model, periodically evaluating it and writing checkpoints.

    Operates on the module-level objects built above (``args``, ``model``,
    ``optim``, ``dataloaders``, ``log``, ``save_dir``, ...). When ``args.resume``
    is set, the model and optimizer states are first restored from that
    directory. Every ``args.test_epochs`` epochs the validation and test
    losses are computed, and the checkpoint with the lowest validation loss
    is kept; an additional numbered checkpoint is written every 10 epochs.
    """
    if args.resume is not None:
        # map_location remaps the stored tensors onto the current device, so a
        # checkpoint written on a GPU can also be resumed on a CPU-only machine.
        flow_state_dict = torch.load(join(args.resume, 'generative_model.npy'), map_location=device)
        optim_state_dict = torch.load(join(args.resume, 'optim.npy'), map_location=device)
        model.load_state_dict(flow_state_dict)
        optim.load_state_dict(optim_state_dict)

    # Data parallelism is used only when requested and more than one GPU is visible.
    use_dp = args.dp and torch.cuda.device_count() > 1

    # Initialize dataparallel if enabled and possible.
    if use_dp:
        print(f'Training using {torch.cuda.device_count()} GPUs')
        model_dp = torch.nn.DataParallel(model.cpu())
        model_dp = model_dp.cuda()
    else:
        model_dp = model

    # Initialize model copy for exponential moving average of params.
    if args.ema_decay > 0:
        model_ema = copy.deepcopy(model)
        ema = da_utils.EMA(args.ema_decay)
        model_ema_dp = torch.nn.DataParallel(model_ema) if use_dp else model_ema
    else:
        # Without EMA the "EMA" handles simply alias the trained model.
        ema = None
        model_ema = model
        model_ema_dp = model_dp

    # Common prefix of every file written for this experiment.
    out_dir = save_dir + 'outputs/%s/' % args.exp_name

    best_nll_val = 1e8
    best_nll_test = 1e8
    for epoch in range(args.start_epoch, args.n_epochs):
        # Named so it cannot be confused with the module-level ``start_epoch``.
        epoch_start_time = time.time()
        train_epoch(args=args, loader=dataloaders['train'], epoch=epoch, model=model, model_dp=model_dp,
                    model_ema=model_ema, ema=ema, device=device, dtype=dtype, property_norms=property_norms,
                    nodes_dist=nodes_dist, dataset_info=dataset_info,
                    gradnorm_queue=gradnorm_queue, optim=optim, prop_dist=prop_dist)
        print(f"Epoch took {time.time() - epoch_start_time:.1f} seconds.")

        if epoch % args.test_epochs == 0:
            if isinstance(model, da_diffusion.DomainAdaptiveDiffusion):
                wandb.log(model.log_info(), commit=True)

            if not args.break_train_epoch and args.train_diffusion and args.analyze_during_train:
                analyze_and_save(args=args, epoch=epoch, model_sample=model_ema, nodes_dist=nodes_dist,
                                 dataset_info=dataset_info, device=device,
                                 prop_dist=prop_dist, n_samples=args.n_stability_samples)
            nll_val = test(args=args, loader=dataloaders['valid'], epoch=epoch, eval_model=model_ema_dp,
                           partition='Val', device=device, dtype=dtype, nodes_dist=nodes_dist,
                           property_norms=property_norms)
            nll_test = test(args=args, loader=dataloaders['test'], epoch=epoch, eval_model=model_ema_dp,
                            partition='Test', device=device, dtype=dtype,
                            nodes_dist=nodes_dist, property_norms=property_norms)

            # Model selection is done on the validation loss; the test loss
            # recorded is the one observed at that same epoch.
            if nll_val < best_nll_val:
                best_nll_val = nll_val
                best_nll_test = nll_test
                if args.save_model:
                    args.current_epoch = epoch + 1
                    utils.save_model(optim, out_dir + 'optim.npy')
                    utils.save_model(model, out_dir + 'generative_model.npy')
                    if args.ema_decay > 0:
                        utils.save_model(model_ema, out_dir + 'generative_model_ema.npy')
                    with open(out_dir + 'args.pickle', 'wb') as f:
                        pickle.dump(args, f)
            print('Val loss: %.4f \t Test loss:  %.4f' % (nll_val, nll_test))
            print('Best val loss: %.4f \t Best test loss:  %.4f' % (best_nll_val, best_nll_test))
            log.write('Val loss: %.4f \t Test loss:  %.4f \n' % (nll_val, nll_test))
            log.write('Best val loss: %.4f \t Best test loss:  %.4f \n' % (best_nll_val, best_nll_test))
            # The log file is never closed, so push the results to disk now.
            log.flush()
            wandb.log({"Val loss ": nll_val}, commit=True)
            wandb.log({"Test loss ": nll_test}, commit=True)
            wandb.log({"Best cross-validated test loss ": best_nll_test}, commit=True)

        # Periodic numbered checkpoint, independent of the validation result.
        if args.save_model and epoch % 10 == 0:
            optim_path = out_dir + 'optim_%d.npy' % epoch
            model_path = out_dir + 'generative_model_%d.npy' % epoch
            utils.save_model(optim, optim_path)
            utils.save_model(model, model_path)
            if args.ema_decay > 0:
                utils.save_model(model_ema, out_dir + 'generative_model_ema_%d.npy' % epoch)
            with open(out_dir + 'args_%d.pickle' % epoch, 'wb') as f:
                pickle.dump(args, f)
            print('Model saved (%d).' % (epoch))


# Start the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
